{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7daa0a33",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.pytorch import PyTorch\n",
    "import boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55b9f811",
   "metadata": {},
   "outputs": [],
   "source": [
    "# S3 key prefix under which this notebook stores its data.\n",
    "key_prefix = \"pt_lightning_ddp_tune\"\n",
    "\n",
    "# SageMaker session and the default bucket it provides.\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "\n",
    "# Execution role that is passed to the estimator further down.\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# Low-level S3 client.\n",
    "s3_client = boto3.client(\"s3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d740c86d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download and unpack CIFAR-10 only when the extracted folder is missing, so that\n",
    "# re-running the notebook does not fetch the archive again. -O pins the archive name\n",
    "# so a leftover partial download is overwritten instead of wget saving to a '.1' copy.\n",
    "!test -d cifar-10-batches-py || (wget -nv -O cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz && tar xzf cifar-10-python.tar.gz && rm cifar-10-python.tar.gz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2469eabb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the extracted CIFAR-10 folder to the bucket and keep the value returned by upload_data.\n",
    "cifar_data_path = sess.upload_data(\n",
    "    path=\"cifar-10-batches-py\",\n",
    "    bucket=bucket,\n",
    "    key_prefix=f\"{key_prefix}/input_data/cifar-10-batches-py\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c448a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip the last \"/\"-separated component, leaving the parent of the uploaded location.\n",
    "cifar_path = cifar_data_path.rpartition(\"/\")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3030f827",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network settings handed to the estimator; both are left as None here.\n",
    "subnets = security_group_ids = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "392c20ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training job definition: runs tune_cifar.py from the src directory.\n",
    "estimator_gpu_tune_cifar = PyTorch(\n",
    "    # Code to run\n",
    "    entry_point=\"tune_cifar.py\",\n",
    "    source_dir=\"src\",\n",
    "    # Permissions and runtime versions\n",
    "    role=role,\n",
    "    framework_version=\"1.10\",\n",
    "    py_version=\"py38\",\n",
    "    # Compute\n",
    "    instance_type=\"ml.g5.12xlarge\",\n",
    "    instance_count=4,\n",
    "    # Networking (both None unless set in the cell above)\n",
    "    subnets=subnets,\n",
    "    security_group_ids=security_group_ids,\n",
    "    # Hyperparameters for the job\n",
    "    hyperparameters={\n",
    "        \"use-gpu\": True,\n",
    "        \"num-samples\": 16,\n",
    "        \"num-workers\": 2,\n",
    "        \"num-epochs\": 20,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5d70f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training, supplying the S3 location in cifar_path as the \"train\" channel.\n",
    "estimator_gpu_tune_cifar.fit(inputs={\"train\": cifar_path})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e4658d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
